{
 "cells": [
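  {
   "cell_type": "markdown",
   "id": "3c5ded58",
   "metadata": {},
   "source": [
    "# 3-3, High-Level API Demonstration\n",
    "\n",
    "PyTorch has no official high-level API, so users generally have to write their own training, validation, and prediction loops.\n",
    "\n",
    "Modeled on Keras, the author wraps PyTorch's nn.Module in the torchkeras.KerasModel class,\n",
    "\n",
    "which implements fit, evaluate, predict, and related methods, and so amounts to a user-defined high-level API.\n",
    "\n",
    "This section uses it to build a linear regression model.\n",
    "\n",
    "In addition, the author borrows functionality from pytorch_lightning to provide a second Keras-style interface, the torchkeras.LightModel class.\n",
    "\n",
    "This section uses it to build a DNN binary classification model.\n",
    "\n",
    "\n",
    "torchkeras.KerasModel and torchkeras.LightModel look powerful, but their source code is very simple, fewer than 200 lines.\n",
    "The training code used in Chapter 1, `1. The PyTorch Modeling Workflow`, is in fact the core source code of the torchkeras library.\n",
    "\n",
    "In practice, some models have input, output, or loss structures that differ from what torchkeras assumes, so calling torchkeras directly may not meet your needs. In that case, copy the source of\n",
    "torchkeras.KerasModel or torchkeras.LightModel and make small changes to the input, output, and loss handling.\n"
   ]
  },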
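  {
   "cell_type": "markdown",
   "id": "7e1a0c31",
   "metadata": {},
   "source": [
    "To make the idea of a user-defined high-level API concrete, here is a minimal sketch of a Keras-style wrapper. It is not the torchkeras source: the class name `MiniKerasModel`, the helper `_run_epoch`, and the toy data are assumptions made for illustration. The two comments inside the loop mark where the input/output and loss handling would be changed for a model with a different structure.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch import nn\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "\n",
    "\n",
    "class MiniKerasModel:\n",
    "    '''Hypothetical stand-in for a Keras-style wrapper around nn.Module.'''\n",
    "\n",
    "    def __init__(self, net, loss_fn, optimizer):\n",
    "        self.net, self.loss_fn, self.optimizer = net, loss_fn, optimizer\n",
    "\n",
    "    def _run_epoch(self, dataloader, train):\n",
    "        self.net.train(train)\n",
    "        total, count = 0.0, 0\n",
    "        with torch.set_grad_enabled(train):\n",
    "            for features, labels in dataloader:     # adapt here for other input layouts\n",
    "                preds = self.net(features)\n",
    "                loss = self.loss_fn(preds, labels)  # adapt here for other loss structures\n",
    "                if train:\n",
    "                    self.optimizer.zero_grad()\n",
    "                    loss.backward()\n",
    "                    self.optimizer.step()\n",
    "                total += loss.item() * len(labels)\n",
    "                count += len(labels)\n",
    "        return total / count  # sample-weighted mean loss over the epoch\n",
    "\n",
    "    def fit(self, train_data, val_data, epochs):\n",
    "        history = []\n",
    "        for epoch in range(1, epochs + 1):\n",
    "            train_loss = self._run_epoch(train_data, train=True)\n",
    "            val_loss = self._run_epoch(val_data, train=False)\n",
    "            history.append({'epoch': epoch, 'train_loss': train_loss, 'val_loss': val_loss})\n",
    "        return history\n",
    "\n",
    "    def evaluate(self, val_data):\n",
    "        return {'val_loss': self._run_epoch(val_data, train=False)}\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def predict(self, dataloader):\n",
    "        self.net.eval()\n",
    "        return torch.cat([self.net(batch[0]) for batch in dataloader])\n",
    "\n",
    "\n",
    "# Toy, noise-free data used only to exercise the sketch.\n",
    "torch.manual_seed(0)\n",
    "X = torch.rand(64, 2)\n",
    "Y = X @ torch.tensor([[2.0], [-3.0]]) + 10.0\n",
    "dl = DataLoader(TensorDataset(X, Y), batch_size=16)\n",
    "net = nn.Linear(2, 1)\n",
    "model = MiniKerasModel(net, nn.MSELoss(), torch.optim.Adam(net.parameters(), lr=0.05))\n",
    "history = model.fit(dl, dl, epochs=5)\n",
    "preds = model.predict(dl)\n",
    "```\n",
    "\n",
    "The real library adds metrics, checkpointing, early stopping, and progress output on top of a loop of this shape."
   ]
  },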
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d047412a",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a469f8d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import datetime\n",
    "\n",
    "# Print a separator bar followed by the current time\n",
    "def printbar():\n",
    "    nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')\n",
    "    print(\"\\n\"+\"==========\"*8 + \"%s\"%nowtime)\n",
    "\n",
    "# On macOS, running PyTorch and matplotlib together in Jupyter requires setting this environment variable\n",
    "os.environ[\"KMP_DUPLICATE_LIB_OK\"]=\"TRUE\" \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db03c304",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install torchkeras==3.2.2 -i https://pypi.python.org/simple"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "932c81f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch \n",
    "import torchkeras \n",
    "\n",
    "\n",
    "print(\"torch.__version__=\"+torch.__version__) \n",
    "print(\"torchkeras.__version__=\"+torchkeras.__version__) "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9ad9ac52",
   "metadata": {},
   "source": [
    "```\n",
    "torch.__version__=1.10.0\n",
    "torchkeras.__version__=3.2.2\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3865cdfe",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "fda456d9",
   "metadata": {},
   "source": [
    "### I. Linear regression model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cf19d513",
   "metadata": {},
   "source": [
    "In this example we wrap the network with torchkeras.KerasModel to implement a linear regression model."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9c5f2e29",
   "metadata": {},
   "source": [
    "**1. Prepare the data**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36470b05",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np \n",
    "import pandas as pd\n",
    "from matplotlib import pyplot as plt \n",
    "import torch\n",
    "from torch import nn\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import Dataset,DataLoader,TensorDataset\n",
    "\n",
    "# Number of samples\n",
    "n = 400\n",
    "\n",
    "# Generate a synthetic dataset\n",
    "X = 10*torch.rand([n,2])-5.0  # torch.rand samples from a uniform distribution \n",
    "w0 = torch.tensor([[2.0],[-3.0]])\n",
    "b0 = torch.tensor([[10.0]])\n",
    "Y = X@w0 + b0 + torch.normal( 0.0,2.0,size = [n,1])  # @ is matrix multiplication; add Gaussian noise\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "abf78477",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize the data\n",
    "\n",
    "%matplotlib inline\n",
    "%config InlineBackend.figure_format = 'svg'\n",
    "\n",
    "plt.figure(figsize = (12,5))\n",
    "ax1 = plt.subplot(121)\n",
    "ax1.scatter(X[:,0],Y[:,0], c = \"b\",label = \"samples\")\n",
    "ax1.legend()\n",
    "plt.xlabel(\"x1\")\n",
    "plt.ylabel(\"y\",rotation = 0)\n",
    "\n",
    "ax2 = plt.subplot(122)\n",
    "ax2.scatter(X[:,1],Y[:,0], c = \"g\",label = \"samples\")\n",
    "ax2.legend()\n",
    "plt.xlabel(\"x2\")\n",
    "plt.ylabel(\"y\",rotation = 0)\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "51121814",
   "metadata": {},
   "source": [
    "![](./data/3-3-回归数据可视化.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ccb5ef52",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the input data pipeline (70% training, 30% validation)\n",
    "ds = TensorDataset(X,Y)\n",
    "ds_train,ds_val = torch.utils.data.random_split(ds,[int(n*0.7),n-int(n*0.7)])\n",
    "dl_train = DataLoader(ds_train,batch_size = 10,shuffle=True,num_workers=2)\n",
    "dl_val = DataLoader(ds_val,batch_size = 10,num_workers=2)\n",
    "\n",
    "features,labels = next(iter(dl_train))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59cf0caf",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "b932d8b9",
   "metadata": {},
   "source": [
    "**2. Define the model**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55605afc",
   "metadata": {},
   "outputs": [],
   "source": [
    "class LinearRegression(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(LinearRegression, self).__init__()\n",
    "        self.fc = nn.Linear(2,1)\n",
    "    \n",
    "    def forward(self,x):\n",
    "        return self.fc(x)\n",
    "\n",
    "net = LinearRegression()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3e2474e",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "from torchkeras import summary \n",
    "\n",
    "summary(net,input_data=features);"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d95bac43",
   "metadata": {},
   "source": [
    "```\n",
    "--------------------------------------------------------------------------\n",
    "Layer (type)                            Output Shape              Param #\n",
    "==========================================================================\n",
    "Linear-1                                     [-1, 1]                    3\n",
    "==========================================================================\n",
    "Total params: 3\n",
    "Trainable params: 3\n",
    "Non-trainable params: 0\n",
    "--------------------------------------------------------------------------\n",
    "Input size (MB): 0.000069\n",
    "Forward/backward pass size (MB): 0.000008\n",
    "Params size (MB): 0.000011\n",
    "Estimated Total Size (MB): 0.000088\n",
    "--------------------------------------------------------------------------\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d401650b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "e69af204",
   "metadata": {},
   "source": [
    "**3. Train the model**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "261f8aaa",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchkeras import KerasModel \n",
    "\n",
    "import torchmetrics\n",
    "\n",
    "net = LinearRegression()\n",
    "model = KerasModel(net=net,\n",
    "                   loss_fn = nn.MSELoss(),\n",
    "                   metrics_dict = {\"mae\":torchmetrics.MeanAbsoluteError()},\n",
    "                   optimizer= torch.optim.Adam(net.parameters(),lr = 0.05))\n",
    "\n",
    "dfhistory = model.fit(train_data=dl_train,\n",
    "      val_data=dl_val,\n",
    "      epochs=20,\n",
    "      ckpt_path='checkpoint.pt',\n",
    "      patience=5,\n",
    "      monitor='val_loss',\n",
    "      mode='min')\n"
   ]
  },
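  {
   "cell_type": "markdown",
   "id": "5b9d2f64",
   "metadata": {},
   "source": [
    "The `patience`, `monitor`, and `mode` arguments passed to `fit` describe early stopping: training halts once the monitored metric has failed to improve for `patience` consecutive epochs. The following is a minimal pure-Python sketch of that rule, not the torchkeras implementation; the class name `EarlyStopper`, its `step` method, and the loss values are hypothetical and exist only for illustration.\n",
    "\n",
    "```python\n",
    "class EarlyStopper:\n",
    "    '''Hypothetical helper that tracks the best value of a monitored metric.'''\n",
    "\n",
    "    def __init__(self, patience=5, mode='min'):\n",
    "        if mode not in ('min', 'max'):\n",
    "            raise ValueError('mode must be min or max')\n",
    "        self.patience = patience\n",
    "        self.mode = mode\n",
    "        self.best = None\n",
    "        self.bad_epochs = 0\n",
    "\n",
    "    def step(self, value):\n",
    "        '''Record one epoch of the metric; return True if training should stop.'''\n",
    "        improved = (\n",
    "            self.best is None\n",
    "            or (self.mode == 'min' and value < self.best)\n",
    "            or (self.mode == 'max' and value > self.best)\n",
    "        )\n",
    "        if improved:\n",
    "            self.best = value      # a real loop would also save a checkpoint here\n",
    "            self.bad_epochs = 0\n",
    "        else:\n",
    "            self.bad_epochs += 1\n",
    "        return self.bad_epochs >= self.patience\n",
    "\n",
    "\n",
    "# Made-up validation losses, for illustration only.\n",
    "val_losses = [3.5, 3.1, 3.2, 3.3, 3.15]\n",
    "stopper = EarlyStopper(patience=3, mode='min')\n",
    "stopped_at = None\n",
    "for epoch, loss in enumerate(val_losses, start=1):\n",
    "    if stopper.step(loss):\n",
    "        stopped_at = epoch\n",
    "        break\n",
    "```\n",
    "\n",
    "With `mode='min'` a lower value counts as an improvement, which suits a loss; an accuracy-like metric would use `mode='max'`."
   ]
  },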
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b0c5864",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize the results\n",
    "\n",
    "%matplotlib inline\n",
    "%config InlineBackend.figure_format = 'svg'\n",
    "\n",
    "w,b = net.state_dict()[\"fc.weight\"],net.state_dict()[\"fc.bias\"]\n",
    "\n",
    "plt.figure(figsize = (12,5))\n",
    "ax1 = plt.subplot(121)\n",
    "ax1.scatter(X[:,0],Y[:,0], c = \"b\",label = \"samples\")\n",
    "ax1.plot(X[:,0],w[0,0]*X[:,0]+b[0],\"-r\",linewidth = 5.0,label = \"model\")\n",
    "ax1.legend()\n",
    "plt.xlabel(\"x1\")\n",
    "plt.ylabel(\"y\",rotation = 0)\n",
    "\n",
    "\n",
    "ax2 = plt.subplot(122)\n",
    "ax2.scatter(X[:,1],Y[:,0], c = \"g\",label = \"samples\")\n",
    "ax2.plot(X[:,1],w[0,1]*X[:,1]+b[0],\"-r\",linewidth = 5.0,label = \"model\")\n",
    "ax2.legend()\n",
    "plt.xlabel(\"x2\")\n",
    "plt.ylabel(\"y\",rotation = 0)\n",
    "\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1f412f0f",
   "metadata": {},
   "source": [
    "**4. Evaluate the model**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5380c5e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "dfhistory.tail()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b69d7227",
   "metadata": {},
   "source": [
    "```\n",
    "\ttrain_loss\ttrain_mae\tval_loss\tval_mae\tepoch\n",
    "15\t4.339620\t1.635648\t3.119237\t1.384351\t16\n",
    "16\t4.313104\t1.631849\t2.999482\t1.352427\t17\n",
    "17\t4.319547\t1.628811\t3.022779\t1.355054\t18\n",
    "18\t4.315403\t1.636815\t3.087339\t1.369488\t19\n",
    "19\t4.266822\t1.627701\t2.937751\t1.330670\t20\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f42e3a0",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "18806439",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "%config InlineBackend.figure_format = 'svg'\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "def plot_metric(dfhistory, metric):\n",
    "    train_metrics = dfhistory[\"train_\"+metric]\n",
    "    val_metrics = dfhistory['val_'+metric]\n",
    "    epochs = range(1, len(train_metrics) + 1)\n",
    "    plt.plot(epochs, train_metrics, 'bo--')\n",
    "    plt.plot(epochs, val_metrics, 'ro-')\n",
    "    plt.title('Training and validation '+ metric)\n",
    "    plt.xlabel(\"Epochs\")\n",
    "    plt.ylabel(metric)\n",
    "    plt.legend([\"train_\"+metric, 'val_'+metric])\n",
    "    plt.show()\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31b33440",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_metric(dfhistory,\"loss\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b1b4b3c3",
   "metadata": {},
   "source": [
    "![](./data/3-3-loss曲线.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31a7c0af",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_metric(dfhistory,\"mae\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9fa665e8",
   "metadata": {},
   "source": [
    "![](./data/3-3-mape曲线.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1578280e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate on the validation set\n",
    "model.evaluate(dl_val)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b2ff4283",
   "metadata": {},
   "source": [
    "```\n",
    "{'val_loss': 2.9377514322598777, 'val_mae': 1.3306695222854614}\n",
    "```"
   ]
  },
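  {
   "cell_type": "markdown",
   "id": "c48e7a90",
   "metadata": {},
   "source": [
    "The two values reported by `evaluate` are a mean squared error (the `nn.MSELoss` value) and a mean absolute error. As a small worked example, the sketch below computes both by hand in pure Python on made-up numbers; `mse` and `mae` are local helper functions, not library calls.\n",
    "\n",
    "```python\n",
    "def mse(y_true, y_pred):\n",
    "    '''Mean of squared differences.'''\n",
    "    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)\n",
    "\n",
    "\n",
    "def mae(y_true, y_pred):\n",
    "    '''Mean of absolute differences.'''\n",
    "    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)\n",
    "\n",
    "\n",
    "# Made-up targets and predictions; the errors are 1, -2, 0, 3.\n",
    "y_true = [10.0, 4.0, -2.0, 7.0]\n",
    "y_pred = [9.0, 6.0, -2.0, 4.0]\n",
    "\n",
    "mse_value = mse(y_true, y_pred)  # (1 + 4 + 0 + 9) / 4\n",
    "mae_value = mae(y_true, y_pred)  # (1 + 2 + 0 + 3) / 4\n",
    "```\n",
    "\n",
    "Since the synthetic targets carry Gaussian noise with standard deviation 2.0, a well-fitted linear model should have an MSE near the noise variance of 4, which is the range the training log settles into."
   ]
  },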
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7972719b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "e1d7ab0a",
   "metadata": {},
   "source": [
    "**5. Use the model**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3cf065a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Predict on the full feature tensor\n",
    "dl = DataLoader(TensorDataset(X))\n",
    "model.predict(dl)[0:10]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "89997a39",
   "metadata": {},
   "source": [
    "```\n",
    "tensor([[ 3.9212],\n",
    "        [ 8.6336],\n",
    "        [ 6.1982],\n",
    "        [ 6.1212],\n",
    "        [-5.0974],\n",
    "        [-6.3183],\n",
    "        [ 4.6588],\n",
    "        [ 5.5349],\n",
    "        [11.9106],\n",
    "        [24.6937]])\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e517835",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Predict on the validation DataLoader\n",
    "model.predict(dl_val)[0:10]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "df8afc8b",
   "metadata": {},
   "source": [
    "```\n",
    "tensor([[-11.0324],\n",
    "        [ 26.2708],\n",
    "        [ 24.8866],\n",
    "        [ 12.2698],\n",
    "        [-12.0984],\n",
    "        [  6.7254],\n",
    "        [ 12.8081],\n",
    "        [ 20.6977],\n",
    "        [  5.4715],\n",
    "        [  2.0188]])\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21f0f0f5",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "4a4d4866",
   "metadata": {},
   "source": [
    "### II. DNN binary classification model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "010aac54",
   "metadata": {},
   "source": [
    "In this example we wrap the network with torchkeras.LightModel to implement a DNN binary classification model.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "23026251",
   "metadata": {},
   "source": [
    "**1. Prepare the data**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e02230f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np \n",
    "import pandas as pd \n",
    "from matplotlib import pyplot as plt\n",
    "import torch\n",
    "from torch import nn\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import Dataset,DataLoader,TensorDataset\n",
    "import torchkeras \n",
    "import pytorch_lightning as pl \n",
    "%matplotlib inline\n",
    "%config InlineBackend.figure_format = 'svg'\n",
    "\n",
    "# Number of positive and negative samples\n",
    "n_positive,n_negative = 2000,2000\n",
    "\n",
    "# Generate positive samples, distributed on a small ring\n",
    "r_p = 5.0 + torch.normal(0.0,1.0,size = [n_positive,1]) \n",
    "theta_p = 2*np.pi*torch.rand([n_positive,1])\n",
    "Xp = torch.cat([r_p*torch.cos(theta_p),r_p*torch.sin(theta_p)],axis = 1)\n",
    "Yp = torch.ones_like(r_p)\n",
    "\n",
    "# Generate negative samples, distributed on a large ring\n",
    "r_n = 8.0 + torch.normal(0.0,1.0,size = [n_negative,1]) \n",
    "theta_n = 2*np.pi*torch.rand([n_negative,1])\n",
    "Xn = torch.cat([r_n*torch.cos(theta_n),r_n*torch.sin(theta_n)],axis = 1)\n",
    "Yn = torch.zeros_like(r_n)\n",
    "\n",
    "# Combine the samples\n",
    "X = torch.cat([Xp,Xn],axis = 0)\n",
    "Y = torch.cat([Yp,Yn],axis = 0)\n",
    "\n",
    "\n",
    "# Visualize\n",
    "plt.figure(figsize = (6,6))\n",
    "plt.scatter(Xp[:,0],Xp[:,1],c = \"r\")\n",
    "plt.scatter(Xn[:,0],Xn[:,1],c = \"g\")\n",
    "plt.legend([\"positive\",\"negative\"]);\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a290d8f7",
   "metadata": {},
   "source": [
    "![](./data/3-3-分类数据可视化.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75547978",
   "metadata": {},
   "outputs": [],
   "source": [
    "ds = TensorDataset(X,Y)\n",
    "\n",
    "ds_train,ds_val = torch.utils.data.random_split(ds,[int(len(ds)*0.7),len(ds)-int(len(ds)*0.7)])\n",
    "dl_train = DataLoader(ds_train,batch_size = 100,shuffle=True,num_workers=2)\n",
    "dl_val = DataLoader(ds_val,batch_size = 100,num_workers=2)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48f4a5c4",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "73579b1b",
   "metadata": {},
   "source": [
    "**2. Define the model**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "881fd06d",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Net(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.fc1 = nn.Linear(2,4)\n",
    "        self.fc2 = nn.Linear(4,8) \n",
    "        self.fc3 = nn.Linear(8,1)\n",
    "        \n",
    "    def forward(self,x):\n",
    "        x = F.relu(self.fc1(x))\n",
    "        x = F.relu(self.fc2(x))\n",
    "        y = self.fc3(x)\n",
    "        return y\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0df9c08b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchkeras \n",
    "from torchkeras.metrics import Accuracy \n",
    "\n",
    "net = Net()\n",
    "loss_fn = nn.BCEWithLogitsLoss()\n",
    "metric_dict = {\"acc\":Accuracy()}\n",
    "\n",
    "optimizer = torch.optim.Adam(net.parameters(), lr=0.05)\n",
    "lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.0001)\n",
    "\n",
    "model = torchkeras.LightModel(net,\n",
    "                   loss_fn = loss_fn,\n",
    "                   metrics_dict= metric_dict,\n",
    "                   optimizer = optimizer,\n",
    "                   lr_scheduler = lr_scheduler,\n",
    "                  )       \n",
    "\n",
    "from torchkeras import summary\n",
    "# Take a batch from the classification DataLoader so summary sees this section's data\n",
    "features,labels = next(iter(dl_train))\n",
    "summary(model,input_data=features);\n"
   ]
  },
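  {
   "cell_type": "markdown",
   "id": "2f6b8d17",
   "metadata": {},
   "source": [
    "`StepLR` multiplies the learning rate by `gamma` once every `step_size` scheduler steps. Assuming the scheduler is stepped once per epoch, the settings above keep the rate at 0.05 for the first 10 epochs and then cut it by a factor of 10,000. The pure-Python sketch below evaluates that closed-form schedule; the function name `step_lr` is a hypothetical helper, not part of PyTorch.\n",
    "\n",
    "```python\n",
    "def step_lr(initial_lr, step_size, gamma, epoch):\n",
    "    '''Learning rate at a 0-based epoch index under step decay.'''\n",
    "    return initial_lr * gamma ** (epoch // step_size)\n",
    "\n",
    "\n",
    "# Same hyperparameters as the StepLR call: lr=0.05, step_size=10, gamma=0.0001.\n",
    "schedule = [step_lr(0.05, 10, 0.0001, epoch) for epoch in range(20)]\n",
    "for epoch in (0, 9, 10, 19):\n",
    "    print(epoch, schedule[epoch])\n",
    "```\n",
    "\n",
    "A `gamma` this small nearly freezes the weights after the first decay step; a milder value such as PyTorch's default of 0.1 gives a more gradual decay if that is what is wanted."
   ]
  },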
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4366f02f",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "2db246f6",
   "metadata": {},
   "source": [
    "**3. Train the model**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c4ef10f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pytorch_lightning as pl     \n",
    "\n",
    "# 1. Set up the callbacks\n",
    "model_ckpt = pl.callbacks.ModelCheckpoint(\n",
    "    monitor='val_acc',\n",
    "    save_top_k=1,\n",
    "    mode='max'\n",
    ")\n",
    "\n",
    "early_stopping = pl.callbacks.EarlyStopping(monitor = 'val_acc',\n",
    "                           patience=3,\n",
    "                           mode = 'max'\n",
    "                          )\n",
    "\n",
    "# 2. Set the training parameters\n",
    "\n",
    "# gpus=0 trains on the CPU, gpus=1 on one GPU, gpus=2 on two GPUs, and gpus=-1 on all available GPUs;\n",
    "# gpus=[0,1] selects GPUs 0 and 1, and gpus=\"0,1,2,3\" selects GPUs 0 through 3.\n",
    "# tpu_cores=1 trains on one TPU core.\n",
    "trainer = pl.Trainer(logger=True,\n",
    "                     min_epochs=3,max_epochs=20,\n",
    "                     gpus=0,\n",
    "                     callbacks = [model_ckpt,early_stopping],\n",
    "                     enable_progress_bar = True) \n",
    "\n",
    "\n",
    "# 3. Start the training loop\n",
    "trainer.fit(model,dl_train,dl_val)\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fd876583",
   "metadata": {},
   "source": [
    "```\n",
    "================================================================================2022-07-16 20:25:49\n",
    "{'epoch': 0, 'val_loss': 0.3484574854373932, 'val_acc': 0.8766666650772095}\n",
    "{'epoch': 0, 'train_loss': 0.5639581680297852, 'train_acc': 0.708214282989502}\n",
    "<<<<<< reach best val_acc : 0.8766666650772095 >>>>>>\n",
    "\n",
    "================================================================================2022-07-16 20:25:54\n",
    "{'epoch': 1, 'val_loss': 0.18654096126556396, 'val_acc': 0.925000011920929}\n",
    "{'epoch': 1, 'train_loss': 0.2512527406215668, 'train_acc': 0.9117857217788696}\n",
    "<<<<<< reach best val_acc : 0.925000011920929 >>>>>>\n",
    "\n",
    "================================================================================2022-07-16 20:25:59\n",
    "{'epoch': 2, 'val_loss': 0.19609291851520538, 'val_acc': 0.9191666841506958}\n",
    "{'epoch': 2, 'train_loss': 0.19115397334098816, 'train_acc': 0.9257143139839172}\n",
    "\n",
    "================================================================================2022-07-16 20:26:04\n",
    "{'epoch': 3, 'val_loss': 0.18749761581420898, 'val_acc': 0.925000011920929}\n",
    "{'epoch': 3, 'train_loss': 0.19545568525791168, 'train_acc': 0.9235714077949524}\n",
    "\n",
    "================================================================================2022-07-16 20:26:09\n",
    "{'epoch': 4, 'val_loss': 0.21518440544605255, 'val_acc': 0.9083333611488342}\n",
    "{'epoch': 4, 'train_loss': 0.1998639553785324, 'train_acc': 0.9192857146263123}\n",
    "```\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "08ca32a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize the results\n",
    "fig, (ax1,ax2) = plt.subplots(nrows=1,ncols=2,figsize = (12,5))\n",
    "ax1.scatter(Xp[:,0],Xp[:,1], c=\"r\")\n",
    "ax1.scatter(Xn[:,0],Xn[:,1],c = \"g\")\n",
    "ax1.legend([\"positive\",\"negative\"]);\n",
    "ax1.set_title(\"y_true\");\n",
    "\n",
    "# net outputs logits (it is trained with BCEWithLogitsLoss), so apply sigmoid before thresholding at 0.5\n",
    "with torch.no_grad():\n",
    "    probs = torch.sigmoid(net(X))\n",
    "Xp_pred = X[torch.squeeze(probs>=0.5)]\n",
    "Xn_pred = X[torch.squeeze(probs<0.5)]\n",
    "\n",
    "ax2.scatter(Xp_pred[:,0],Xp_pred[:,1],c = \"r\")\n",
    "ax2.scatter(Xn_pred[:,0],Xn_pred[:,1],c = \"g\")\n",
    "ax2.legend([\"positive\",\"negative\"]);\n",
    "ax2.set_title(\"y_pred\");\n"
   ]
  },
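  {
   "cell_type": "markdown",
   "id": "9d03e5ab",
   "metadata": {},
   "source": [
    "Because the network is trained with `nn.BCEWithLogitsLoss`, its raw outputs are logits rather than probabilities. A probability cutoff of 0.5 corresponds to a logit cutoff of 0, since sigmoid(0) = 0.5 and the sigmoid is monotonic. The pure-Python sketch below checks this on a few made-up logit values; `sigmoid` is a locally defined helper.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def sigmoid(z):\n",
    "    '''Logistic function mapping a logit to a probability.'''\n",
    "    return 1.0 / (1.0 + math.exp(-z))\n",
    "\n",
    "\n",
    "logits = [-2.0, -0.1, 0.0, 0.3, 0.5, 4.0]     # made-up values\n",
    "by_probability = [sigmoid(z) >= 0.5 for z in logits]\n",
    "by_logit = [z >= 0.0 for z in logits]          # equivalent rule on raw logits\n",
    "by_wrong_cutoff = [z >= 0.5 for z in logits]   # thresholding the logit itself at 0.5\n",
    "```\n",
    "\n",
    "The first two lists agree on every value, while the third labels the logit 0.3 as negative even though its probability is above 0.5."
   ]
  },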
  {
   "cell_type": "markdown",
   "id": "5580f23e",
   "metadata": {},
   "source": [
    "![](./data/3-3-分类结果可视化.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e5bb9c8d",
   "metadata": {},
   "source": [
    "**4. Evaluate the model**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d817e473",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "%config InlineBackend.figure_format = 'svg'\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "def plot_metric(dfhistory, metric):\n",
    "    train_metrics = dfhistory[\"train_\"+metric]\n",
    "    val_metrics = dfhistory['val_'+metric]\n",
    "    epochs = range(1, len(train_metrics) + 1)\n",
    "    plt.plot(epochs, train_metrics, 'bo--')\n",
    "    plt.plot(epochs, val_metrics, 'ro-')\n",
    "    plt.title('Training and validation '+ metric)\n",
    "    plt.xlabel(\"Epochs\")\n",
    "    plt.ylabel(metric)\n",
    "    plt.legend([\"train_\"+metric, 'val_'+metric])\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "abdb3627",
   "metadata": {},
   "outputs": [],
   "source": [
    "dfhistory  = model.get_history() \n",
    "plot_metric(dfhistory,\"loss\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1a0095c1",
   "metadata": {},
   "source": [
    "![](https://tva1.sinaimg.cn/large/e6c9d24egy1h491k7wtl0j20f70a6wes.jpg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "709dc9ee",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a74c677b",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_metric(dfhistory,\"acc\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b91fe873",
   "metadata": {},
   "source": [
    "![](https://tva1.sinaimg.cn/large/e6c9d24egy1h491k8if3hj20ev0aaglw.jpg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0517f63a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate using the best checkpoint\n",
    "trainer.test(ckpt_path='best', dataloaders=dl_val,verbose = False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0174ba79",
   "metadata": {},
   "source": [
    "```\n",
    "{'test_loss': 0.18654096126556396, 'test_acc': 0.925000011920929}\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c96fb7e1",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "92d14a17",
   "metadata": {},
   "source": [
    "**5. Use the model**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eda36b16",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert the predicted logits to probabilities (torch.sigmoid replaces the deprecated F.sigmoid)\n",
    "predictions = torch.sigmoid(torch.cat(trainer.predict(model, dl_val, ckpt_path='best'))) \n",
    "predictions "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "de9c34ba",
   "metadata": {},
   "source": [
    "```\n",
    "tensor([[0.3873],\n",
    "        [0.0028],\n",
    "        [0.8772],\n",
    "        ...,\n",
    "        [0.9886],\n",
    "        [0.4970],\n",
    "        [0.8094]])\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa29f557",
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict(model,dl):\n",
    "    # Run the model in eval mode without tracking gradients and return the raw logits\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        result = torch.cat([model.forward(t[0]) for t in dl])\n",
    "    return result\n",
    "\n",
    "print(model.device)\n",
    "predictions = torch.sigmoid(predict(model,dl_val)[:10]) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0116222c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "357d80cf",
   "metadata": {},
   "source": [
    "**6. Save the model**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f0435d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(trainer.checkpoint_callback.best_model_path)\n",
    "print(trainer.checkpoint_callback.best_model_score)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1570f270",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_loaded = torchkeras.LightModel.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "08737fd8",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8dc4bc27",
   "metadata": {},
   "outputs": [],
   "source": []
  },
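  {
   "cell_type": "markdown",
   "id": "1cde75b6",
   "metadata": {},
   "source": [
    "**If this book has helped you and you would like to encourage the author, remember to give the project a star ⭐️ and share it with your friends 😊!** \n",
    "\n",
    "If you would like to discuss any part of the book further with the author, feel free to leave a message on the WeChat official account \"算法美食屋\". The author's time and energy are limited, so replies will be given as circumstances allow.\n",
    "\n",
    "You can also reply with the keyword **加群** (join group) in the official account's backend to join the reader discussion group.\n",
    "\n",
    "![算法美食屋logo.png](https://tva1.sinaimg.cn/large/e6c9d24egy1h41m2zugguj20k00b9q46.jpg)"
   ]
  }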
 ],
 "metadata": {
  "jupytext": {
   "cell_metadata_filter": "-all",
   "main_language": "python"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
